{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run evaluation on different action policies, e.g. VLA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from VLABench.evaluation.evaluator import Evaluator\n",
    "from VLABench.evaluation.model.policy.openvla import OpenVLA\n",
    "from VLABench.evaluation.model.policy.base import RandomPolicy\n",
    "from VLABench.tasks import *\n",
    "from VLABench.robots import *\n",
    "\n",
    "demo_tasks = [\"select_fruit\"]\n",
    "unseen = True\n",
    "save_dir = \"/home/shiduo/project/VLABench/logs\"\n",
    "\n",
    "model_ckpt = \"/remote-home1/pjliu/openvla-7b\"\n",
    "lora_ckpt = \"/remote-home1/pjliu/openvla/weights/select_fruit+CSv1+lora/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Init evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator = Evaluator(\n",
    "    tasks=demo_tasks,\n",
    "    n_episodes=2,\n",
    "    max_substeps=10,   \n",
    "    save_dir=save_dir,\n",
    "    visulization=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load basic random policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_policy = RandomPolicy(model=None)\n",
    "result = evaluator.evaluate(random_policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load policies, take OpenVLA as example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = OpenVLA(\n",
    "    model_ckpt=model_ckpt,\n",
    "    lora_ckpt=lora_ckpt,\n",
    "    norm_config_file=os.path.join(os.getenv(\"VLABENCH_ROOT\"), \"configs/model/openvla_config.json\")\n",
    ")\n",
    "\n",
    "result = evaluator.evaluate(policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run evaluation on different VLMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from VLABench.evaluation.model.vlm import *\n",
    "from VLABench.evaluation.evaluator import VLMEvaluator\n",
    "\n",
    "vlm_name = \"GPT_4v\" # valid names: [\"GPT_4v\", \"Qwen2_VL\", \"InternVL2\", \"MiniCPM_V2_6\", \"GLM4v\", \"Llava_NeXT\"]\n",
    "fewshot_num = 0\n",
    "task_list = [\"mesh_and_texture/select_fruit\"]\n",
    "\n",
    "def initialize_model(model_name, *args, **kwargs):\n",
    "    cls = globals().get(model_name)\n",
    "    if cls is None:\n",
    "        raise ValueError(f\"Model '{model_name}' not found in the current namespace.\")\n",
    "    \n",
    "    return cls(*args, **kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vlm = initialize_model(vlm_name)\n",
    "evaluator = VLMEvaluator(\n",
    "    tasks=task_list,\n",
    "    n_episodes=2,\n",
    "    data_path=os.path.join(os.getenv(\"VLABENCH_ROOT\"), \"../dataset\", \"vlm\"),\n",
    "    save_path=os.path.join(os.getenv(\"VLABENCH_ROOT\"), \"../logs/vlm\"),\n",
    ")\n",
    "\n",
    "evaluator.evaluate(vlm, few_shot_num=fewshot_num)\n",
    "result=evaluator.get_final_score_dict(vlm_name)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vlabench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
